from tensorflow import keras

class TextCNNSmall:
    """Small 1-D convolutional text classifier with a single sigmoid output.

    The network is built and compiled in the constructor:
    token indices of length ``MAX_LEN`` -> trainable Embedding initialised
    from a pretrained matrix -> Conv1D(kernel_size=3, relu) ->
    MaxPool1D(pool_size=3) -> Flatten -> Dropout(0.5) -> Dense(32, relu) ->
    Dense(1, sigmoid). It is compiled with binary cross-entropy loss, the
    Adam optimizer and an accuracy metric.

    Args:
        model_conf: Model configuration mapping. Keys read:
            ``"MAX_LEN"`` -- length of each input sequence;
            ``"w2c_len"`` -- embedding dimension (also used as the Conv1D
            filter count);
            ``"emb_model"`` -- object exposing ``embedding_weights`` (its
            ``len`` is used as the vocabulary size) and ``get_np_weights()``
            (the initial embedding matrix).
        train_conf: Training options read by :meth:`fit`: ``"batch_size"``,
            ``"epochs"`` and ``"verbose"``; missing keys fall back to 64, 3
            and 1. When omitted, a fresh
            ``{"batch_size": 64, "epochs": 3, "verbose": 1}`` is used.
    """

    def __init__(self, model_conf, train_conf=None):
        self.model_conf = model_conf
        # Build a new dict per instance: a dict literal used directly as the
        # default value would be one object shared by every instance.
        if train_conf is None:
            train_conf = {"batch_size": 64, "epochs": 3, "verbose": 1}
        self.train_conf = train_conf
        self.__build_structure__()

    def __build_structure__(self):
        """Build the Keras model, print its summary, compile it into ``self.model``."""
        # NOTE(review): dunder-style names are reserved for Python itself; the
        # name is kept because subclasses/callers may already reference it.
        max_len = self.model_conf["MAX_LEN"]
        emb_dim = self.model_conf["w2c_len"]
        emb_model = self.model_conf["emb_model"]

        inputs = keras.layers.Input(shape=(max_len,))
        # The Input layer already fixes the sequence length, so the
        # (deprecated in Keras 3) `input_length` argument is not passed.
        embedding_layer = keras.layers.Embedding(
            output_dim=emb_dim,
            input_dim=len(emb_model.embedding_weights),
            weights=[emb_model.get_np_weights()],
            trainable=True,
        )
        embedded = embedding_layer(inputs)

        # Single convolution branch; the filter count equals the embedding dimension.
        conv = keras.layers.Conv1D(filters=emb_dim, kernel_size=3, activation='relu')(embedded)
        pooled = keras.layers.MaxPool1D(pool_size=3)(conv)
        flat = keras.layers.Flatten()(pooled)

        hidden = keras.layers.Dropout(0.5)(flat)
        hidden = keras.layers.Dense(32, activation='relu')(hidden)
        pred = keras.layers.Dense(units=1, activation='sigmoid')(hidden)

        self.model = keras.models.Model(inputs=inputs, outputs=pred)
        self.model.summary()
        self.model.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])

    def fit(self, x_train, y_train, x_test, y_test):
        """Train the model; ``(x_test, y_test)`` is passed as validation data.

        Batch size, epoch count and verbosity come from ``self.train_conf``
        (defaults 64, 3, 1). Returns the object produced by ``Model.fit``.
        """
        history = self.model.fit(x_train, y_train,
                                 batch_size=self.train_conf.get("batch_size", 64),
                                 epochs=self.train_conf.get("epochs", 3),
                                 validation_data=(x_test, y_test),
                                 verbose=self.train_conf.get("verbose", 1))
        return history

    def evaluate(self, x_test, y_test):
        """Return ``Model.evaluate`` on the given data (loss plus the accuracy metric)."""
        return self.model.evaluate(x_test, y_test)

    def predict(self, sentences):
        """Return the sigmoid output for each input sequence (one value per row)."""
        return self.model.predict(sentences)

    def save(self, path):
        """Save the model to *path* via ``Model.save``; no-op when no model is set."""
        if self.model is not None:
            self.model.save(path)

    def load(self, path):
        """Replace ``self.model`` with the model loaded from *path*."""
        # `load_model` lives in `keras.models`; `keras.load_model` does not exist
        # in `tensorflow.keras` and would raise AttributeError.
        self.model = keras.models.load_model(path)
